{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 2.1.0\n",
      "\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "loading embedding ...\n",
      "embedding loaded as `glove_embedding`\n"
     ]
    }
   ],
   "source": [
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 5365.33it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:05<00:00, 3205.80it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 310569.71it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 349915.35it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 3031577.24it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 8384.04it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:04<00:00, 3939.48it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 125562.34it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 96621.02it/s]\n",
      "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:07<00:00, 269.84it/s]\n",
      "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 18841/18841 [01:27<00:00, 216.37it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 122/122 [00:00<00:00, 7746.89it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 1115/1115 [00:14<00:00, 77.28it/s] \n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 63447.62it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 88946.88it/s]\n",
      "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 273.77it/s]\n",
      "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:04<00:00, 226.13it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 237/237 [00:00<00:00, 8707.90it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2237.22it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 101299.30it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 97484.78it/s]\n",
      "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 269.12it/s]\n",
      "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:10<00:00, 212.35it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.CDSSMPreprocessor(fixed_length_left=10, fixed_length_right=10)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "valid_pack_processed = preprocessor.transform(dev_pack_raw)\n",
    "test_pack_processed = preprocessor.transform(test_pack_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10, 9644)     0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 10, 9644)     0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_1 (Conv1D)               (None, 10, 64)       1851712     text_left[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_2 (Conv1D)               (None, 10, 64)       1851712     text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 10, 64)       0           conv1d_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dropout_2 (Dropout)             (None, 10, 64)       0           conv1d_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "global_max_pooling1d_1 (GlobalM (None, 64)           0           dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "global_max_pooling1d_2 (GlobalM (None, 64)           0           dropout_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 64)           4160        global_max_pooling1d_1[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 64)           4160        global_max_pooling1d_2[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 64)           4160        dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 64)           4160        dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 1)            0           dense_2[0][0]                    \n",
      "                                                                 dense_4[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_5 (Dense)                 (None, 1)            2           dot_1[0][0]                      \n",
      "==================================================================================================\n",
      "Total params: 3,720,066\n",
      "Trainable params: 3,720,066\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.CDSSM()\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "model.params['filters'] = 64\n",
    "model.params['kernel_size'] = 3\n",
    "model.params['strides'] = 1\n",
    "model.params['padding'] = 'same'\n",
    "model.params['conv_activation_func'] = 'tanh'\n",
    "model.params['w_initializer'] = 'glorot_normal'\n",
    "model.params['b_initializer'] = 'zeros'\n",
    "model.params['mlp_num_layers'] = 1\n",
    "model.params['mlp_num_units'] = 64\n",
    "model.params['mlp_num_fan_out'] = 64\n",
    "model.params['mlp_activation_func'] = 'tanh'\n",
    "model.params['dropout_rate'] = 0.8\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num batches: 102\n"
     ]
    }
   ],
   "source": [
    "pred_x, pred_y = test_pack_processed[:].unpack()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))\n",
    "train_generator = mz.DataGenerator(\n",
    "    train_pack_processed,\n",
    "    mode='pair',\n",
    "    num_dup=2,\n",
    "    num_neg=1,\n",
    "    batch_size=20\n",
    ")\n",
    "print('num batches:', len(train_generator))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "102/102 [==============================] - 65s 635ms/step - loss: 0.8021\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.42304382290227854 - normalized_discounted_cumulative_gain@5(0.0): 0.49915948768338086 - mean_average_precision(0.0): 0.46037758752542035\n",
      "Epoch 2/20\n",
      "102/102 [==============================] - 45s 445ms/step - loss: 0.5966\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.43781763271520285 - normalized_discounted_cumulative_gain@5(0.0): 0.520097599085372 - mean_average_precision(0.0): 0.4762598411822459\n",
      "Epoch 3/20\n",
      "102/102 [==============================] - 46s 447ms/step - loss: 0.4992\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44923101748788413 - normalized_discounted_cumulative_gain@5(0.0): 0.5136672947113214 - mean_average_precision(0.0): 0.4803110559647868\n",
      "Epoch 4/20\n",
      "102/102 [==============================] - 46s 446ms/step - loss: 0.4143\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.467714954371615 - normalized_discounted_cumulative_gain@5(0.0): 0.5353130653753986 - mean_average_precision(0.0): 0.5017560318255102\n",
      "Epoch 5/20\n",
      "102/102 [==============================] - 46s 451ms/step - loss: 0.3489\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4706511291875292 - normalized_discounted_cumulative_gain@5(0.0): 0.5275832328072992 - mean_average_precision(0.0): 0.5026243479583462\n",
      "Epoch 6/20\n",
      "102/102 [==============================] - 45s 443ms/step - loss: 0.3231\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4641151831570107 - normalized_discounted_cumulative_gain@5(0.0): 0.5219564667466021 - mean_average_precision(0.0): 0.4934049132027672\n",
      "Epoch 7/20\n",
      "102/102 [==============================] - 44s 433ms/step - loss: 0.2695\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4760514687477512 - normalized_discounted_cumulative_gain@5(0.0): 0.5285019348702702 - mean_average_precision(0.0): 0.49994736585416333\n",
      "Epoch 8/20\n",
      "102/102 [==============================] - 47s 457ms/step - loss: 0.2331\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4661781341235399 - normalized_discounted_cumulative_gain@5(0.0): 0.5260071867435453 - mean_average_precision(0.0): 0.4922605321622356\n",
      "Epoch 9/20\n",
      "102/102 [==============================] - 46s 453ms/step - loss: 0.1942\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4645719968337067 - normalized_discounted_cumulative_gain@5(0.0): 0.5238558790194195 - mean_average_precision(0.0): 0.48851468090847294\n",
      "Epoch 10/20\n",
      "102/102 [==============================] - 45s 444ms/step - loss: 0.1734\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4600910137969285 - normalized_discounted_cumulative_gain@5(0.0): 0.5320923473092672 - mean_average_precision(0.0): 0.48703092961044614\n",
      "Epoch 11/20\n",
      "102/102 [==============================] - 47s 464ms/step - loss: 0.1644\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45786306386326225 - normalized_discounted_cumulative_gain@5(0.0): 0.5246949873542252 - mean_average_precision(0.0): 0.48502089087514016\n",
      "Epoch 12/20\n",
      "102/102 [==============================] - 45s 443ms/step - loss: 0.1560\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4567332855642369 - normalized_discounted_cumulative_gain@5(0.0): 0.528074374789356 - mean_average_precision(0.0): 0.4905494464640722\n",
      "Epoch 13/20\n",
      "102/102 [==============================] - 45s 440ms/step - loss: 0.1365\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4698836510431016 - normalized_discounted_cumulative_gain@5(0.0): 0.5317255666034969 - mean_average_precision(0.0): 0.49152222966181813\n",
      "Epoch 14/20\n",
      "102/102 [==============================] - 45s 437ms/step - loss: 0.1263\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.46841048236088156 - normalized_discounted_cumulative_gain@5(0.0): 0.5197949838102164 - mean_average_precision(0.0): 0.4887341126171474\n",
      "Epoch 15/20\n",
      "102/102 [==============================] - 46s 449ms/step - loss: 0.1208\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4591952265806063 - normalized_discounted_cumulative_gain@5(0.0): 0.5306329604843507 - mean_average_precision(0.0): 0.4956899590808506\n",
      "Epoch 16/20\n",
      "102/102 [==============================] - 44s 430ms/step - loss: 0.0977\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4698790408918565 - normalized_discounted_cumulative_gain@5(0.0): 0.5355447042513717 - mean_average_precision(0.0): 0.5005823464725863\n",
      "Epoch 17/20\n",
      "102/102 [==============================] - 44s 436ms/step - loss: 0.0975\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4699380823064665 - normalized_discounted_cumulative_gain@5(0.0): 0.5335843828018585 - mean_average_precision(0.0): 0.4945873841691485\n",
      "Epoch 18/20\n",
      "102/102 [==============================] - 44s 431ms/step - loss: 0.1070\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4558123512520484 - normalized_discounted_cumulative_gain@5(0.0): 0.5280824209271964 - mean_average_precision(0.0): 0.49009730599920476\n",
      "Epoch 19/20\n",
      "102/102 [==============================] - 45s 446ms/step - loss: 0.0846\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.47468488947984894 - normalized_discounted_cumulative_gain@5(0.0): 0.5462940161373581 - mean_average_precision(0.0): 0.5050693971440435\n",
      "Epoch 20/20\n",
      "102/102 [==============================] - 43s 425ms/step - loss: 0.0833\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4571219999869289 - normalized_discounted_cumulative_gain@5(0.0): 0.527098973778668 - mean_average_precision(0.0): 0.4884211445807594\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=20, callbacks=[evaluate], workers=1, use_multiprocessing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
